[Operator] Add Conv3d forward function #412

Open · wants to merge 1 commit into master
Conversation

@Gxiandy Gxiandy commented Jan 12, 2025

PR Category

Operator

Type of Change

New Feature

Description

Add Conv3d forward function and related tests
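For context, a minimal usage sketch, assuming flag_gems' standard `flag_gems.enable()` entry point and that conv3d is registered there (the review below notes the registration is still missing); shapes are taken from the benchmark tables below:

```python
import torch
import flag_gems

flag_gems.enable()  # routes registered aten ops to Gems kernels

x = torch.randn(16, 32, 24, 24, 24, device="cuda", dtype=torch.float16)
w = torch.randn(24, 16, 3, 3, 3, device="cuda", dtype=torch.float16)
y = torch.nn.functional.conv3d(x, w, stride=1, padding=2, groups=2)
```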

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

Operator: conv3d Performance Test (dtype=torch.float16, mode=cuda, level=core)

| Status | Torch Latency (ms) | Gems Latency (ms) | Gems Speedup | Size Detail |
| --- | --- | --- | --- | --- |
| SUCCESS | 6.790144 | 6.793216 | 1.000 | {'input': torch.Size([104, 16, 32, 32, 32]), 'weight': torch.Size([32, 16, 4, 4, 4]), 'bias': None, 'groups': 1, 'stride': 1, 'padding': 0} |
| SUCCESS | 4.712448 | 4.705280 | 1.002 | {'input': torch.Size([64, 32, 18, 180, 18]), 'weight': torch.Size([32, 32, 5, 5, 5]), 'bias': None, 'groups': 1, 'stride': 2, 'padding': 1} |
| SUCCESS | 0.943104 | 0.941056 | 1.002 | {'input': torch.Size([4, 32, 110, 110, 10]), 'weight': torch.Size([64, 32, 5, 5, 5]), 'bias': None, 'groups': 1, 'stride': 2, 'padding': 1} |
| SUCCESS | 2.338816 | 2.342912 | 0.998 | {'input': torch.Size([4, 64, 110, 110, 10]), 'weight': torch.Size([16, 64, 5, 5, 5]), 'bias': None, 'groups': 1, 'stride': 2, 'padding': 1} |
| SUCCESS | 0.208896 | 0.208896 | 1.000 | {'input': torch.Size([16, 32, 120, 12, 12]), 'weight': torch.Size([24, 32, 3, 3, 3]), 'bias': None, 'groups': 1, 'stride': 2, 'padding': 1} |
| SUCCESS | 7.403520 | 7.422976 | 0.997 | {'input': torch.Size([16, 32, 240, 24, 24]), 'weight': torch.Size([24, 16, 3, 3, 3]), 'bias': None, 'groups': 2, 'stride': 1, 'padding': 1} |
| SUCCESS | 0.224256 | 0.224256 | 1.000 | {'input': torch.Size([16, 32, 24, 24, 24]), 'weight': torch.Size([24, 16, 3, 3, 3]), 'bias': None, 'groups': 2, 'stride': 2, 'padding': 2} |
| SUCCESS | 0.933888 | 0.931840 | 1.002 | {'input': torch.Size([16, 32, 24, 24, 24]), 'weight': torch.Size([24, 16, 3, 3, 3]), 'bias': None, 'groups': 2, 'stride': 1, 'padding': 2} |

Operator: conv3d Performance Test (dtype=torch.float32, mode=cuda, level=core)

| Status | Torch Latency (ms) | Gems Latency (ms) | Gems Speedup | Size Detail |
| --- | --- | --- | --- | --- |
| SUCCESS | 35.387390 | 35.483646 | 0.997 | {'input': torch.Size([104, 16, 32, 32, 32]), 'weight': torch.Size([32, 16, 4, 4, 4]), 'bias': None, 'groups': 1, 'stride': 1, 'padding': 0} |
| SUCCESS | 20.566015 | 21.796864 | 0.944 | {'input': torch.Size([64, 32, 18, 180, 18]), 'weight': torch.Size([32, 32, 5, 5, 5]), 'bias': None, 'groups': 1, 'stride': 2, 'padding': 1} |
| SUCCESS | 5.250048 | 5.249024 | 1.000 | {'input': torch.Size([4, 32, 110, 110, 10]), 'weight': torch.Size([64, 32, 5, 5, 5]), 'bias': None, 'groups': 1, 'stride': 2, 'padding': 1} |
| SUCCESS | 5.134336 | 5.169152 | 0.993 | {'input': torch.Size([4, 64, 110, 110, 10]), 'weight': torch.Size([16, 64, 5, 5, 5]), 'bias': None, 'groups': 1, 'stride': 2, 'padding': 1} |
| SUCCESS | 0.425984 | 0.431104 | 0.988 | {'input': torch.Size([16, 32, 120, 12, 12]), 'weight': torch.Size([24, 32, 3, 3, 3]), 'bias': None, 'groups': 1, 'stride': 2, 'padding': 1} |
| SUCCESS | 103.790588 | 103.080963 | 1.007 | {'input': torch.Size([16, 32, 240, 24, 24]), 'weight': torch.Size([24, 16, 3, 3, 3]), 'bias': None, 'groups': 2, 'stride': 1, 'padding': 1} |
| SUCCESS | 0.407552 | 0.414720 | 0.983 | {'input': torch.Size([16, 32, 24, 24, 24]), 'weight': torch.Size([24, 16, 3, 3, 3]), 'bias': None, 'groups': 2, 'stride': 2, 'padding': 2} |
| SUCCESS | 10.667008 | 10.584064 | 1.008 | {'input': torch.Size([16, 32, 24, 24, 24]), 'weight': torch.Size([24, 16, 3, 3, 3]), 'bias': None, 'groups': 2, 'stride': 1, 'padding': 2} |

Operator: conv3d Performance Test (dtype=torch.bfloat16, mode=cuda, level=core)

| Status | Torch Latency (ms) | Gems Latency (ms) | Gems Speedup | Size Detail |
| --- | --- | --- | --- | --- |
| SUCCESS | 11.389952 | 11.388928 | 1.000 | {'input': torch.Size([104, 16, 32, 32, 32]), 'weight': torch.Size([32, 16, 4, 4, 4]), 'bias': None, 'groups': 1, 'stride': 1, 'padding': 0} |
| SUCCESS | 4.552704 | 4.488192 | 1.014 | {'input': torch.Size([64, 32, 18, 180, 18]), 'weight': torch.Size([32, 32, 5, 5, 5]), 'bias': None, 'groups': 1, 'stride': 2, 'padding': 1} |
| SUCCESS | 0.941056 | 0.943104 | 0.998 | {'input': torch.Size([4, 32, 110, 110, 10]), 'weight': torch.Size([64, 32, 5, 5, 5]), 'bias': None, 'groups': 1, 'stride': 2, 'padding': 1} |
| SUCCESS | 3.181568 | 3.180544 | 1.000 | {'input': torch.Size([4, 64, 110, 110, 10]), 'weight': torch.Size([16, 64, 5, 5, 5]), 'bias': None, 'groups': 1, 'stride': 2, 'padding': 1} |
| SUCCESS | 0.256000 | 0.256000 | 1.000 | {'input': torch.Size([16, 32, 120, 12, 12]), 'weight': torch.Size([24, 32, 3, 3, 3]), 'bias': None, 'groups': 1, 'stride': 2, 'padding': 1} |
| SUCCESS | 7.410688 | 7.752704 | 0.956 | {'input': torch.Size([16, 32, 240, 24, 24]), 'weight': torch.Size([24, 16, 3, 3, 3]), 'bias': None, 'groups': 2, 'stride': 1, 'padding': 1} |
| SUCCESS | 0.222208 | 0.223232 | 0.995 | {'input': torch.Size([16, 32, 24, 24, 24]), 'weight': torch.Size([24, 16, 3, 3, 3]), 'bias': None, 'groups': 2, 'stride': 2, 'padding': 2} |
| SUCCESS | 0.931840 | 0.934912 | 0.997 | {'input': torch.Size([16, 32, 24, 24, 24]), 'weight': torch.Size([24, 16, 3, 3, 3]), 'bias': None, 'groups': 2, 'stride': 1, 'padding': 2} |

@StrongSpoon StrongSpoon self-requested a review January 13, 2025 02:18
@StrongSpoon StrongSpoon (Collaborator):

impressive performance! we will review it soon.

@StrongSpoon StrongSpoon (Collaborator) left a comment:

developing conv3d is quite hard. you could refer to the implementation of conv2d and figure out a better way. if you are still confused, you are welcome to contact us in the community via WeChat.

- [16, 32, 120, 12, 12, 24, 3, 3, 3, 2, 1, 1]
- [16, 32, 240, 24, 24, 24, 3, 3, 3, 1, 1, 2]
- [16, 32, 24, 24, 24, 24, 3, 3, 3, 2, 2, 2]
- [16, 32, 24, 24, 24, 24, 3, 3, 3, 1, 2, 2]
StrongSpoon (Collaborator):

usually we suggest setting 5 shapes for core mode.
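for example, the list above could be trimmed to 5 entries that still cover the stride/padding/groups variation; a hypothetical selection reusing configurations from the benchmark tables in this PR (format as in the diff: [N, C, D, H, W, K, kT, kH, kW, stride, padding, groups]; the list name is assumed):

```python
# Hypothetical trimmed core-mode shape list, reusing configurations already
# benchmarked above; format: [N, C, D, H, W, K, kT, kH, kW, stride, padding, groups].
CONV3D_CORE_SHAPES = [
    [104, 16, 32, 32, 32, 32, 4, 4, 4, 1, 0, 1],  # stride 1, no padding
    [64, 32, 18, 180, 18, 32, 5, 5, 5, 2, 1, 1],  # strided, elongated H
    [4, 64, 110, 110, 10, 16, 5, 5, 5, 2, 1, 1],  # large spatial extent, small batch
    [16, 32, 240, 24, 24, 24, 3, 3, 3, 1, 1, 2],  # grouped, deep temporal dim
    [16, 32, 24, 24, 24, 24, 3, 3, 3, 2, 2, 2],   # grouped, stride 2, padding 2
]
```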

for k_w in [32 * i for i in range(1, 4)]
for stride in [1, (2, 2, 2), (3, 3, 3)]
for padding in [0, (1, 1, 1), (0, 1, 2)]
for groups in [1, 2, 4, 8]
StrongSpoon (Collaborator):

personally I think the number of test cases is too large. pick ~20 shapes from classic networks and the performance data will be convincing enough.
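as an illustration of network-derived shapes, an excerpt based on C3D-style layer dimensions (batch size and list name are hypothetical; the full ~20 entries would be drawn from several classic 3D networks):

```python
# Illustrative excerpt of network-derived shapes (C3D-style layers, batch 8);
# same format as above. Hypothetical list name.
CONV3D_NETWORK_SHAPES = [
    [8, 3, 16, 112, 112, 64, 3, 3, 3, 1, 1, 1],   # C3D conv1a
    [8, 64, 16, 56, 56, 128, 3, 3, 3, 1, 1, 1],   # C3D conv2a
    [8, 128, 8, 28, 28, 256, 3, 3, 3, 1, 1, 1],   # C3D conv3a
    [8, 256, 4, 14, 14, 512, 3, 3, 3, 1, 1, 1],   # C3D conv4a
]
```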

bench = Conv3dBenchmark(
input_fn=conv3d_input_fn,
op_name="conv3d",
torch_op=torch.nn.functional.conv3d,
StrongSpoon (Collaborator):

since your implementation of conv3d is not registered into the aten library, the function called in the benchmark is still the torch op. please add the registration in src/flag_gems/__init__.py and update the benchmark results.
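a minimal sketch of what that registration might look like, assuming the `lib.impl` pattern flag_gems uses for its other CUDA operators (the exact surrounding code in src/flag_gems/__init__.py will differ, and the import path of the new op is assumed):

```python
import torch
from .ops import conv3d  # assumed import path for the new operator

aten_lib = torch.library.Library("aten", "IMPL")

def enable(lib=aten_lib):
    # ... existing registrations ...
    lib.impl("conv3d", conv3d, "CUDA")  # dispatch aten::conv3d to the Gems kernel
```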

@@ -264,7 +264,7 @@ def set_more_shapes(self):

@pytest.mark.conv2d
def test_perf_conv2d():
-def conv2d_input_fn(shape, dtype, device):
+def conv3d_input_fn(shape, dtype, device):
StrongSpoon (Collaborator):

it's not time to enable the benchmark of conv2d right now, and the same goes for conv3d. I'll mark it as skip.
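for reference, a standard pytest marker does it; a sketch with an assumed test name and body:

```python
import pytest

@pytest.mark.skip(reason="conv benchmarks are not enabled yet")
@pytest.mark.conv3d
def test_perf_conv3d():
    bench.run()  # hypothetical body, mirroring the conv2d benchmark test
```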

shape_input, dtype=dtype, device=flag_gems.device, requires_grad=True
)
ref_inp = to_reference(inp, True)
torch.backends.cudnn.allow_tf32 = True
StrongSpoon (Collaborator):

usually we set allow_tf32 to False since the precision of tf32 is not satisfactory.
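concretely, with the standard PyTorch flags:

```python
import torch

# Keep TF32 off so float32 references are computed at full precision.
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
```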

stride=strides,
padding=paddings,
dilation=dilations,
).to(dtype)
StrongSpoon (Collaborator):

gems_assert_close will cast the reference tensor to dtype. you don't need to do this again.
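i.e. the comparison reduces to the following sketch (`res_out`, `ref_inp`, `ref_weight` names are assumed from the surrounding test):

```python
ref_out = torch.nn.functional.conv3d(
    ref_inp, ref_weight, stride=strides, padding=paddings, dilation=dilations
)  # no trailing .to(dtype): gems_assert_close casts the reference itself
gems_assert_close(res_out, ref_out, dtype)
```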

for t in range(T):
for r in range(R):
for s in range(S):
for c in range(C_per_group):
StrongSpoon (Collaborator):

using multiple nested loops might not be a good way to compute convolution. try loading tensors with high-dimensional indexes and using the tl.dot primitive.
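for illustration, a minimal Triton sketch of that direction: keep the (t, r, s) tap loops but replace the scalar channel loop with a (BLOCK_M, BLOCK_C) × (BLOCK_C, BLOCK_K) tl.dot tile. it assumes groups=1, dilation=1, bias=None, contiguous NCDHW tensors, and uniform integer stride/padding; the kernel name and tiling are hypothetical, not the flag_gems conv2d implementation:

```python
import triton
import triton.language as tl

@triton.jit
def conv3d_dot_kernel(
    x_ptr, w_ptr, y_ptr,
    N, C, D, H, W,            # input:  (N, C, D, H, W), contiguous
    K, T, R, S,               # weight: (K, C, T, R, S), contiguous
    D_out, H_out, W_out,      # output: (N, K, D_out, H_out, W_out)
    stride, padding,          # single ints applied to all three spatial dims
    BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_C: tl.constexpr,
):
    # Implicit GEMM: rows index (n, od, oh, ow), columns index output channels.
    # Launch grid: (cdiv(N * D_out * H_out * W_out, BLOCK_M), cdiv(K, BLOCK_K)).
    offs_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_k = tl.program_id(1) * BLOCK_K + tl.arange(0, BLOCK_K)
    # Decompose the flat row index into output coordinates.
    ow = offs_m % W_out
    oh = (offs_m // W_out) % H_out
    od = (offs_m // (W_out * H_out)) % D_out
    n = offs_m // (W_out * H_out * D_out)
    acc = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    for t in range(T):
        id_ = od * stride + t - padding
        for r in range(R):
            ih = oh * stride + r - padding
            for s in range(S):
                iw = ow * stride + s - padding
                for c0 in range(0, C, BLOCK_C):
                    offs_c = c0 + tl.arange(0, BLOCK_C)
                    # Input tile (BLOCK_M, BLOCK_C); out-of-range taps read 0.
                    x_off = (((n[:, None] * C + offs_c[None, :]) * D
                              + id_[:, None]) * H + ih[:, None]) * W + iw[:, None]
                    x_mask = ((n[:, None] < N) & (offs_c[None, :] < C)
                              & (id_[:, None] >= 0) & (id_[:, None] < D)
                              & (ih[:, None] >= 0) & (ih[:, None] < H)
                              & (iw[:, None] >= 0) & (iw[:, None] < W))
                    a = tl.load(x_ptr + x_off, mask=x_mask, other=0.0)
                    # Weight tile (BLOCK_C, BLOCK_K) for this (t, r, s) tap.
                    w_off = (((offs_k[None, :] * C + offs_c[:, None]) * T + t)
                             * R + r) * S + s
                    w_mask = (offs_c[:, None] < C) & (offs_k[None, :] < K)
                    b = tl.load(w_ptr + w_off, mask=w_mask, other=0.0)
                    acc = tl.dot(a, b, acc)  # one tl.dot replaces the scalar c-loop
    y_off = (((n[:, None] * K + offs_k[None, :]) * D_out + od[:, None])
             * H_out + oh[:, None]) * W_out + ow[:, None]
    y_mask = (offs_m[:, None] < N * D_out * H_out * W_out) & (offs_k[None, :] < K)
    tl.store(y_ptr + y_off, acc, mask=y_mask)
```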
